import sys
import os
import cv2
import time
from imageio import save
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.ecbsr import ECBSR
import argparse, yaml

parser = argparse.ArgumentParser(description='ECBSR inference')

# Every command-line option as (flag, type, default, help), listed in the
# order it should appear in --help.
_OPTION_TABLE = (
    # YAML file; its keys are merged over the parsed options at startup.
    ('--config', str, None, 'pre-config file for inference'),
    # ECBSR network hyper-parameters and weights.
    ('--scale', int, 4, 'scale for sr network'),
    ('--colors', int, 1, '1(Y channls of YCbCr), 3(RGB)'),
    ('--m_ecbsr', int, 4, 'number of ecb'),
    ('--c_ecbsr', int, 8, 'channels of ecb'),
    ('--idt_ecbsr', int, 0, 'incorporate identity mapping in ecb or not'),
    ('--act_type', str, 'prelu', 'prelu, relu, splus, rrelu'),
    ('--pretrain', str, None, 'path of pretrained model'),
    # Input / output image locations.
    ('--input_path', str, None, 'path to the input image file'),
    ('--output_path', str, None, 'path to save the output image file'),
)


def _register_options(arg_parser, table):
    """Issue one ``add_argument`` call per ``(flag, type, default, help)`` row."""
    for flag, arg_type, default, help_text in table:
        arg_parser.add_argument(flag, type=arg_type, default=default, help=help_text)


_register_options(parser, _OPTION_TABLE)


def load_config(args):
    """Overlay options from the YAML file named by ``args.config`` onto ``args``.

    Keys present in the YAML file overwrite the corresponding command-line
    values. Does nothing when ``args.config`` is empty or None.
    """
    if not args.config:
        return
    # `with` closes the file handle (the original left it open).
    with open(args.config) as f:
        yaml_args = yaml.load(f, Loader=yaml.FullLoader)
    # An empty YAML file loads as None; treat that as "no overrides".
    vars(args).update(yaml_args or {})


def merge_sr_luma(yuv_up, y_sr):
    """Replace the luma plane of an upscaled YCrCb image with the SR output.

    Args:
        yuv_up: array (H, W, 3) in OpenCV YCrCb channel order, already resized
            to the super-resolved size. It is not modified.
        y_sr: array (H, W) holding the super-resolved Y plane.

    Returns:
        uint8 array (H, W, 3). Values are rounded to the nearest integer and
        clipped to [0, 255], so interpolation overshoot saturates instead of
        wrapping around in the uint8 cast.
    """
    merged = yuv_up.astype(np.float32)  # astype copies, leaving yuv_up intact
    merged[..., 0] = y_sr
    return np.clip(np.rint(merged), 0, 255).astype(np.uint8)


def main():
    """Super-resolve one image with ECBSR and write the result to disk.

    Only the Y plane of OpenCV's YCrCb conversion is run through the network.
    The chroma planes are bicubically upscaled and recombined with the SR luma.

    Raises:
        ValueError: --pretrain, --input_path or --output_path is missing, or
            --colors is not 1.
        FileNotFoundError: the input image cannot be read.
        IOError: the output image cannot be written.
    """
    args = parser.parse_args()
    load_config(args)

    if args.pretrain is None:
        raise ValueError('the pretrain path is invalid!')
    if args.input_path is None or args.output_path is None:
        raise ValueError('both --input_path and --output_path must be given!')
    if args.colors != 1:
        # The pipeline below feeds a single luma plane to the network, which
        # cannot match a model built for 3 input channels.
        raise ValueError('only colors=1 (Y channel) is supported, got {}'.format(args.colors))

    device = torch.device('cpu')

    # Build the network, load its weights and switch to inference mode.
    model_ecbsr = ECBSR(module_nums=args.m_ecbsr, channel_nums=args.c_ecbsr,
                        with_idt=args.idt_ecbsr, act_type=args.act_type,
                        scale=args.scale, colors=args.colors).to(device)
    print("load pretrained model: {}!".format(args.pretrain))
    # map_location: a checkpoint saved from a GPU would otherwise fail to load
    # on this CPU-only path. NOTE: torch.load unpickles the file -- only load
    # trusted checkpoints.
    model_ecbsr.load_state_dict(torch.load(args.pretrain, map_location=device))
    model_ecbsr.eval()

    bgr_img = cv2.imread(args.input_path)
    if bgr_img is None:
        # cv2.imread signals failure by returning None; stop here instead of
        # crashing later on `.shape`.
        raise FileNotFoundError('Unable to load image: {}'.format(args.input_path))

    t1 = time.time()
    h, w = bgr_img.shape[:2]
    yuv_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2YCrCb).astype(np.float32)
    # (h, w) luma plane -> (1, 1, h, w) batch for the network.
    y_t = torch.from_numpy(np.ascontiguousarray(yuv_img[..., 0]))
    y_t = y_t.unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model_ecbsr(y_t).clamp(0, 255)
    # Presumably (h*scale, w*scale) after dropping batch/channel dims -- it
    # must match the resize below for the merge to work.
    pred = pred.cpu().numpy().squeeze(0).squeeze(0)

    # Upscale all three planes bicubically. The chroma planes are kept and
    # the luma plane is replaced by the network output.
    out_size = (w * args.scale, h * args.scale)  # cv2 expects (width, height)
    yuv_up = cv2.resize(yuv_img, out_size, interpolation=cv2.INTER_CUBIC)
    bgr_out = cv2.cvtColor(merge_sr_luma(yuv_up, pred), cv2.COLOR_YCrCb2BGR)

    t2 = time.time()
    print("time cost::::", t2 - t1)

    # cv2.imwrite reports failure via its boolean return value, not an exception.
    if not cv2.imwrite(args.output_path, bgr_out):
        raise IOError('failed to write output image: {}'.format(args.output_path))


if __name__ == '__main__':
    main()




    



